﻿using Singer.Core;
using System.Linq.Expressions;

namespace Singer.Middleware.EFCore;

/// <summary>
/// Base class for EF Core repositories. Wraps a <typeparamref name="TDbContext"/> and the
/// <see cref="DbSet{TEntity}"/> for <typeparamref name="TEntity"/> with common CRUD and query helpers.
/// </summary>
/// <typeparam name="TDbContext">The DbContext type.</typeparam>
/// <typeparam name="TEntity">The entity type.</typeparam>
public abstract class BaseRepository<TDbContext, TEntity> : IBaseRepository<TDbContext, TEntity>
    where TDbContext : EFCoreDbContext
    where TEntity : class, IEntity
{
    /// <summary>The context this repository operates on.</summary>
    public TDbContext DbContext { get; }

    /// <summary>The entity set obtained from <see cref="DbContext"/> for <typeparamref name="TEntity"/>.</summary>
    public DbSet<TEntity> Table { get; }

    /// <summary>
    /// Creates the repository over the given context.
    /// </summary>
    /// <param name="dbContext">The context to use; must not be null.</param>
    /// <exception cref="ArgumentNullException"><paramref name="dbContext"/> is null.</exception>
    public BaseRepository(TDbContext dbContext)
    {
        // Fail with a descriptive exception instead of a NullReferenceException on dbContext.Set.
        DbContext = dbContext ?? throw new ArgumentNullException(nameof(dbContext));
        Table = dbContext.Set<TEntity>();
    }

    /// <summary>Attaches a single entity to the context.</summary>
    public void Attach(TEntity entity) => DbContext.Attach(entity);

    /// <summary>Attaches a sequence of entities to the context.</summary>
    public void Attach(IEnumerable<TEntity> entities) => DbContext.AttachRange(entities);

    /// <summary>
    /// Stops tracking the entity by setting its entry state to <see cref="EntityState.Detached"/>.
    /// </summary>
    /// <remarks>
    /// The state must be Detached, not Unchanged: Unchanged keeps the entity tracked
    /// (and would start tracking an entity that was not tracked before).
    /// </remarks>
    public void RemoveAttach(TEntity entity) => DbContext.Entry(entity).State = EntityState.Detached;

    /// <summary>
    /// Stops tracking every entity in the sequence by setting each entry state to
    /// <see cref="EntityState.Detached"/>.
    /// </summary>
    public void RemoveAttach(IEnumerable<TEntity> entities)
    {
        foreach (var entity in entities)
        {
            DbContext.Entry(entity).State = EntityState.Detached;
        }
    }

    /// <summary>Calls <c>SaveChanges</c> on the context and returns its result.</summary>
    public int SaveChanges() => DbContext.SaveChanges();

    /// <summary>Calls <c>SaveChangesAsync</c> on the context and returns its task.</summary>
    public Task<int> SaveChangesAsync() => DbContext.SaveChangesAsync();

    /// <summary>Marks the entity for insertion and optionally saves.</summary>
    /// <param name="entity">The entity to add.</param>
    /// <param name="saveChanges">When true, the context is saved immediately.</param>
    /// <returns>The result of <c>SaveChanges</c>, or 0 when <paramref name="saveChanges"/> is false.</returns>
    public int Add(TEntity entity, bool saveChanges = true)
    {
        DbContext.Add(entity);
        return SaveIfRequested(saveChanges);
    }

    /// <summary>Marks the entity for insertion and optionally saves asynchronously.</summary>
    /// <param name="entity">The entity to add.</param>
    /// <param name="saveChanges">When true, the context is saved immediately.</param>
    /// <returns>The result of <c>SaveChangesAsync</c>, or 0 when <paramref name="saveChanges"/> is false.</returns>
    public Task<int> AddAsync(TEntity entity, bool saveChanges = true)
    {
        DbContext.Add(entity);
        return SaveIfRequestedAsync(saveChanges);
    }

    /// <summary>Marks all entities for insertion and optionally saves.</summary>
    /// <param name="entities">The entities to add.</param>
    /// <param name="saveChanges">When true, the context is saved immediately.</param>
    /// <returns>The result of <c>SaveChanges</c>, or 0 when <paramref name="saveChanges"/> is false.</returns>
    public int Add(IEnumerable<TEntity> entities, bool saveChanges = true)
    {
        DbContext.AddRange(entities);
        return SaveIfRequested(saveChanges);
    }

    /// <summary>Marks all entities for insertion and optionally saves asynchronously.</summary>
    /// <param name="entities">The entities to add.</param>
    /// <param name="saveChanges">When true, the context is saved immediately.</param>
    /// <returns>The result of <c>SaveChangesAsync</c>, or 0 when <paramref name="saveChanges"/> is false.</returns>
    public Task<int> AddAsync(IEnumerable<TEntity> entities, bool saveChanges = true)
    {
        DbContext.AddRange(entities);
        return SaveIfRequestedAsync(saveChanges);
    }

    /// <summary>Marks the entity as updated and optionally saves.</summary>
    /// <param name="entity">The entity to update.</param>
    /// <param name="saveChanges">When true, the context is saved immediately.</param>
    /// <returns>The result of <c>SaveChanges</c>, or 0 when <paramref name="saveChanges"/> is false.</returns>
    public int Update(TEntity entity, bool saveChanges = true)
    {
        DbContext.Update(entity);
        return SaveIfRequested(saveChanges);
    }

    /// <summary>Marks the entity as updated and optionally saves asynchronously.</summary>
    /// <param name="entity">The entity to update.</param>
    /// <param name="saveChanges">When true, the context is saved immediately.</param>
    /// <returns>The result of <c>SaveChangesAsync</c>, or 0 when <paramref name="saveChanges"/> is false.</returns>
    public Task<int> UpdateAsync(TEntity entity, bool saveChanges = true)
    {
        DbContext.Update(entity);
        return SaveIfRequestedAsync(saveChanges);
    }

    /// <summary>Marks all entities as updated and optionally saves.</summary>
    /// <param name="entities">The entities to update.</param>
    /// <param name="saveChanges">When true, the context is saved immediately.</param>
    /// <returns>The result of <c>SaveChanges</c>, or 0 when <paramref name="saveChanges"/> is false.</returns>
    public int Update(IEnumerable<TEntity> entities, bool saveChanges = true)
    {
        DbContext.UpdateRange(entities);
        return SaveIfRequested(saveChanges);
    }

    /// <summary>Marks all entities as updated and optionally saves asynchronously.</summary>
    /// <param name="entities">The entities to update.</param>
    /// <param name="saveChanges">When true, the context is saved immediately.</param>
    /// <returns>The result of <c>SaveChangesAsync</c>, or 0 when <paramref name="saveChanges"/> is false.</returns>
    public Task<int> UpdateAsync(IEnumerable<TEntity> entities, bool saveChanges = true)
    {
        DbContext.UpdateRange(entities);
        return SaveIfRequestedAsync(saveChanges);
    }

    /// <summary>Marks the entity for deletion and optionally saves.</summary>
    /// <param name="entity">The entity to delete.</param>
    /// <param name="saveChanges">When true, the context is saved immediately.</param>
    /// <returns>The result of <c>SaveChanges</c>, or 0 when <paramref name="saveChanges"/> is false.</returns>
    public int Delete(TEntity entity, bool saveChanges = true)
    {
        DbContext.Remove(entity);
        return SaveIfRequested(saveChanges);
    }

    /// <summary>Marks the entity for deletion and optionally saves asynchronously.</summary>
    /// <param name="entity">The entity to delete.</param>
    /// <param name="saveChanges">When true, the context is saved immediately.</param>
    /// <returns>The result of <c>SaveChangesAsync</c>, or 0 when <paramref name="saveChanges"/> is false.</returns>
    public Task<int> DeleteAsync(TEntity entity, bool saveChanges = true)
    {
        DbContext.Remove(entity);
        return SaveIfRequestedAsync(saveChanges);
    }

    /// <summary>Marks all entities for deletion and optionally saves.</summary>
    /// <param name="entities">The entities to delete.</param>
    /// <param name="saveChanges">When true, the context is saved immediately.</param>
    /// <returns>The result of <c>SaveChanges</c>, or 0 when <paramref name="saveChanges"/> is false.</returns>
    public int Delete(IEnumerable<TEntity> entities, bool saveChanges = true)
    {
        DbContext.RemoveRange(entities);
        return SaveIfRequested(saveChanges);
    }

    /// <summary>Marks all entities for deletion and optionally saves asynchronously.</summary>
    /// <param name="entities">The entities to delete.</param>
    /// <param name="saveChanges">When true, the context is saved immediately.</param>
    /// <returns>The result of <c>SaveChangesAsync</c>, or 0 when <paramref name="saveChanges"/> is false.</returns>
    public Task<int> DeleteAsync(IEnumerable<TEntity> entities, bool saveChanges = true)
    {
        DbContext.RemoveRange(entities);
        return SaveIfRequestedAsync(saveChanges);
    }

    /// <summary>
    /// Loads the entities matching <paramref name="expWhere"/> (tracked), marks them for deletion
    /// and saves the context.
    /// </summary>
    /// <param name="expWhere">The filter selecting the entities to delete; validated with <c>CheckNotNull</c>.</param>
    /// <returns>The result of <c>SaveChanges</c>.</returns>
    public int Delete(Expression<Func<TEntity, bool>> expWhere)
    {
        expWhere.CheckNotNull(nameof(expWhere));
        var matches = List(expWhere);
        if (matches.Count > 0)
            DbContext.RemoveRange(matches);
        return DbContext.SaveChanges();
    }

    /// <summary>
    /// Asynchronously loads the entities matching <paramref name="expWhere"/> (tracked), marks them
    /// for deletion and saves the context.
    /// </summary>
    /// <remarks>
    /// The context is saved even when nothing matched, exactly like the synchronous overload, so
    /// whether other pending changes are flushed does not depend on the data.
    /// </remarks>
    /// <param name="expWhere">The filter selecting the entities to delete; validated with <c>CheckNotNull</c>.</param>
    /// <returns>The result of <c>SaveChangesAsync</c>.</returns>
    public async Task<int> DeleteAsync(Expression<Func<TEntity, bool>> expWhere)
    {
        expWhere.CheckNotNull(nameof(expWhere));
        var matches = await ListAsync(expWhere);
        if (matches.Count > 0)
            DbContext.RemoveRange(matches);
        return await DbContext.SaveChangesAsync();
    }

    /// <summary>Returns the first entity matching the optional filter, or null when there is none.</summary>
    /// <param name="expWhere">Optional filter; when null no filter is applied.</param>
    /// <param name="noTracking">When true, the query runs with <c>AsNoTracking</c>.</param>
    public TEntity? Get(Expression<Func<TEntity, bool>>? expWhere = null, bool noTracking = false)
        => Query(expWhere, noTracking).FirstOrDefault();

    /// <summary>Asynchronously returns the first entity matching the optional filter, or null when there is none.</summary>
    /// <param name="expWhere">Optional filter; when null no filter is applied.</param>
    /// <param name="noTracking">When true, the query runs with <c>AsNoTracking</c>.</param>
    public Task<TEntity?> GetAsync(Expression<Func<TEntity, bool>>? expWhere = null, bool noTracking = false)
        => Query(expWhere, noTracking).FirstOrDefaultAsync();

    /// <summary>Counts the entities matching the optional filter.</summary>
    /// <param name="expWhere">Optional filter; when null all entities are counted.</param>
    public int Count(Expression<Func<TEntity, bool>>? expWhere = null)
        => Query(expWhere, noTracking: true).Count();

    /// <summary>Asynchronously counts the entities matching the optional filter.</summary>
    /// <param name="expWhere">Optional filter; when null all entities are counted.</param>
    public Task<int> CountAsync(Expression<Func<TEntity, bool>>? expWhere = null)
        => Query(expWhere, noTracking: true).CountAsync();

    /// <summary>Determines whether any entity matches the optional filter.</summary>
    /// <param name="expWhere">Optional filter; when null the check is whether any entity exists.</param>
    public bool Any(Expression<Func<TEntity, bool>>? expWhere = null)
        => Query(expWhere, noTracking: true).Any();

    /// <summary>Asynchronously determines whether any entity matches the optional filter.</summary>
    /// <param name="expWhere">Optional filter; when null the check is whether any entity exists.</param>
    public Task<bool> AnyAsync(Expression<Func<TEntity, bool>>? expWhere = null)
        => Query(expWhere, noTracking: true).AnyAsync();

    /// <summary>Returns the maximum of the selected value over the optionally filtered entities.</summary>
    /// <param name="maxSelector">Projection producing the value to maximize.</param>
    /// <param name="expWhere">Optional filter; when null no filter is applied.</param>
    public T? Max<T>(Expression<Func<TEntity, T>> maxSelector, Expression<Func<TEntity, bool>>? expWhere = null)
        => Query(expWhere, noTracking: true).Max(maxSelector);

    /// <summary>Asynchronously returns the maximum of the selected value over the optionally filtered entities.</summary>
    /// <param name="maxSelector">Projection producing the value to maximize.</param>
    /// <param name="expWhere">Optional filter; when null no filter is applied.</param>
    public Task<T?> MaxAsync<T>(Expression<Func<TEntity, T>> maxSelector, Expression<Func<TEntity, bool>>? expWhere = null)
        => Query(expWhere, noTracking: true).MaxAsync(maxSelector);

    /// <summary>Returns the minimum of the selected value over the optionally filtered entities.</summary>
    /// <param name="minSelector">Projection producing the value to minimize.</param>
    /// <param name="expWhere">Optional filter; when null no filter is applied.</param>
    public T? Min<T>(Expression<Func<TEntity, T>> minSelector, Expression<Func<TEntity, bool>>? expWhere = null)
        => Query(expWhere, noTracking: true).Min(minSelector);

    /// <summary>Asynchronously returns the minimum of the selected value over the optionally filtered entities.</summary>
    /// <param name="minSelector">Projection producing the value to minimize.</param>
    /// <param name="expWhere">Optional filter; when null no filter is applied.</param>
    public Task<T> MinAsync<T>(Expression<Func<TEntity, T>> minSelector, Expression<Func<TEntity, bool>>? expWhere = null)
        => Query(expWhere, noTracking: true).MinAsync(minSelector);

    /// <summary>Projects the optionally filtered entities and returns the first projected value, or default when there is none.</summary>
    /// <param name="selector">Projection applied to each entity.</param>
    /// <param name="expWhere">Optional filter; when null no filter is applied.</param>
    public T? SelectValue<T>(Expression<Func<TEntity, T>> selector, Expression<Func<TEntity, bool>>? expWhere = null)
        => Query(expWhere, noTracking: true).Select(selector).FirstOrDefault();

    /// <summary>Asynchronously projects the optionally filtered entities and returns the first projected value, or default when there is none.</summary>
    /// <param name="selector">Projection applied to each entity.</param>
    /// <param name="expWhere">Optional filter; when null no filter is applied.</param>
    public Task<T?> SelectValueAsync<T>(Expression<Func<TEntity, T>> selector, Expression<Func<TEntity, bool>>? expWhere = null)
        => Query(expWhere, noTracking: true).Select(selector).FirstOrDefaultAsync();

    /// <summary>Projects the optionally filtered entities and returns all projected values.</summary>
    /// <param name="selector">Projection applied to each entity.</param>
    /// <param name="expWhere">Optional filter; when null no filter is applied.</param>
    public List<T> SelectValueList<T>(Expression<Func<TEntity, T>> selector, Expression<Func<TEntity, bool>>? expWhere = null)
        => Query(expWhere, noTracking: true).Select(selector).ToList();

    /// <summary>Asynchronously projects the optionally filtered entities and returns all projected values.</summary>
    /// <param name="selector">Projection applied to each entity.</param>
    /// <param name="expWhere">Optional filter; when null no filter is applied.</param>
    public Task<List<T>> SelectValueListAsync<T>(Expression<Func<TEntity, T>> selector, Expression<Func<TEntity, bool>>? expWhere = null)
        => Query(expWhere, noTracking: true).Select(selector).ToListAsync();

    /// <summary>Returns all entities matching the optional filter.</summary>
    /// <param name="expWhere">Optional filter; when null all entities are returned.</param>
    /// <param name="noTracking">When true, the query runs with <c>AsNoTracking</c>.</param>
    public List<TEntity> List(Expression<Func<TEntity, bool>>? expWhere = null, bool noTracking = false)
        => Query(expWhere, noTracking).ToList();

    /// <summary>Asynchronously returns all entities matching the optional filter.</summary>
    /// <param name="expWhere">Optional filter; when null all entities are returned.</param>
    /// <param name="noTracking">When true, the query runs with <c>AsNoTracking</c>.</param>
    public Task<List<TEntity>> ListAsync(Expression<Func<TEntity, bool>>? expWhere = null, bool noTracking = false)
        => Query(expWhere, noTracking).ToListAsync();

    /// <summary>Builds the base query over <see cref="Table"/> with the optional no-tracking flag and filter applied.</summary>
    private IQueryable<TEntity> Query(Expression<Func<TEntity, bool>>? expWhere, bool noTracking)
    {
        var query = Table.AsQueryable();
        if (noTracking)
            query = query.AsNoTracking();
        if (expWhere != null)
            query = query.Where(expWhere);
        return query;
    }

    /// <summary>Saves the context when requested; otherwise returns 0 without saving.</summary>
    private int SaveIfRequested(bool saveChanges)
        => saveChanges ? DbContext.SaveChanges() : 0;

    /// <summary>Saves the context asynchronously when requested; otherwise returns a completed task with 0.</summary>
    private Task<int> SaveIfRequestedAsync(bool saveChanges)
        => saveChanges ? DbContext.SaveChangesAsync() : Task.FromResult(0);
}
